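""" This script runs the optimization. It accepts up to two command-line
    arguments, in this order: "time period" "prob exit index / seed"
      - time period: 'month', 'month_local', 'fortnight' or 'bootstrap'
      - second argument: an integer index into p_exit_runs (the exit
        probabilities) or, for 'bootstrap', the integer bootstrap seed
    (if the arguments are missing or cannot be parsed, the script reverts to
    the defaults: time_period = 'month', p_exit = 1.0,
    do_myopic_optimization = True)

"""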

if __name__ == '__main__':
    #%% IMPORTS -------------------------------------------------------------------------
    import pandas as pd
    import scipy.optimize
    import sys
    import copy
    import multiprocessing as mp

    from src.models_new import simulation
    from src.run_scripts import utils

    #%% SET THE PARAMETERS --------------------------------------------------------------
    # Run-control switches.
    do_optimization = True    # gates the whole optimization section further down
    save_results = True       # gates writing the estimated parameters to path_save

    # Starting value only: the command-line handling below assigns this flag again.
    do_myopic_optimization = False

    # These two flags are not read anywhere in the visible part of this script.
    do_local_optimization = False
    use_myopic = False

    # Numerical settings.
    workers = 60      # number of processes in the multiprocessing pool
    mri_max = 2.15    # forwarded to the simulation via args_config['mri_max']

    # Read in external parameters:
    #   sys.argv[1] selects the run mode: 'month_local', 'month', 'fortnight' or
    #   'bootstrap'.
    #   sys.argv[2] is an integer: an index into p_exit_runs, or the bootstrap
    #   seed in 'bootstrap' mode.
    # A missing, unparsable or unrecognised argument reverts everything to the
    # defaults set in the except branch at the bottom.
    p_exit_runs = [0.75, 0.95]  # [0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95, 0.99]
    try:
        time_period = sys.argv[1]
        # Keep the mode exactly as requested: 'month_local' is rewritten to
        # 'month' just below, but the optimizer selection still checks input_1.
        input_1 = time_period
        if time_period == 'month_local':
            print('With local optimizer')
            time_period = 'month'
            beta = 0.99
            tau_multiplier = 1
            try:
                p_exit_index = int(sys.argv[2])
                p_exit = p_exit_runs[p_exit_index]
                do_myopic_optimization = False
            except (IndexError, ValueError):
                # No usable exit-probability index: run the myopic version.
                p_exit = 1.0
                do_myopic_optimization = True
            bootstrap_seed = ''
        elif time_period == 'month':
            p_exit_index = int(sys.argv[2])
            p_exit = p_exit_runs[p_exit_index]
            beta = 0.99
            do_myopic_optimization = False
            tau_multiplier = 1
            bootstrap_seed = ''
            # The value printed is the exit probability (it was labelled "seed").
            print(f'Doing nonmyopic optimization, p_exit: {p_exit}')
        elif time_period == 'fortnight':
            # The index is not used in this mode, but it is still parsed so that
            # a missing/invalid second argument reverts to the defaults as before.
            p_exit_index = int(sys.argv[2])
            beta = 0.995
            tau_multiplier = 2
            do_myopic_optimization = True  # This is a robustness test
            p_exit = 1.0
            bootstrap_seed = ''
        elif time_period == 'bootstrap':
            beta = 0.99
            tau_multiplier = 1
            do_myopic_optimization = True
            p_exit = 1.0
            bootstrap_seed = int(sys.argv[2])
        else:
            # Previously an unknown mode (for example the '-f' an IPython kernel
            # puts in sys.argv[1]) matched no branch, left beta/p_exit/... undefined
            # and crashed later with a NameError. Route it to the defaults instead.
            raise ValueError(f'unrecognised time period: {time_period!r}')

    except (IndexError, ValueError) as err:
        # IndexError: an argument is missing or the index is out of range.
        # ValueError: the second argument is not an integer, or the mode is unknown.
        print('Revert to default time period/prob exit')
        print(f'Reason: {err}')
        time_period = 'month'
        input_1 = 'month'
        beta = 0.99
        tau_multiplier = 1
        p_exit = 1.0
        do_myopic_optimization = True
        bootstrap_seed = ''

    # Choose the file the estimated parameters are written to.
    if not do_myopic_optimization:
        path_save = f'./models/smm/params_smm_with_diff_new_myopic_{p_exit}.csv'

    elif time_period == 'bootstrap':
        path_save = f'./models/smm/bootstrap_output/params_smm_with_diff_new{bootstrap_seed}.csv'

    elif time_period == 'month':
        path_save = f'./models/smm/params_smm_with_diff_new{bootstrap_seed}.csv'

    else:
        # A myopic run for any other period (e.g. 'fortnight') used to match no
        # branch, so path_save stayed undefined and the print below raised a
        # NameError. Give such runs their own file so they cannot overwrite the
        # monthly results.
        # NOTE(review): the file name is chosen here, not taken from existing
        # code - confirm it matches what downstream scripts expect.
        path_save = f'./models/smm/params_smm_with_diff_new_{time_period}.csv'

    print(path_save)
    print(f"Prob exit: {p_exit}, time period: {time_period}")

    # Search box for the estimated parameters, one (lower, upper) pair per name.
    # Insertion order is significant: list(bounds.keys()) becomes x_names and
    # tuple(bounds.values()) is what the optimizers receive.
    m_0_upper = 0.2 / tau_multiplier
    d_0_upper = 40.0 / tau_multiplier
    d_1_upper = 30.0 / tau_multiplier
    bounds = {
        # m_* parameters
        "m_0_low": (0.04, m_0_upper),
        "m_1_low": (-0.15, 0.00),
        "m_0_mid": (0.04, m_0_upper),
        "m_1_mid": (0, 0.05),
        "m_0_high": (0.01, m_0_upper),
        "m_1_high": (0.0, 0.1),
        "m_2": (0.00, 0.1),
        # a_1_* parameters
        "a_1_low": (0.1, 2.0),
        "a_1_mid": (0.1, 2.0),
        "a_1_high": (0.5, 40.0),
        # d_* parameters (upper bounds scaled by tau_multiplier)
        "d_0": (0, d_0_upper),
        "d_1": (2, d_1_upper),
        # remaining parameters
        "mu_0": (0.4, 0.85),
        "sigma_0": (0.2, 1.6),
        "gamma": (0, 5.0),
        "gamma_negative": (0.0, 30.0),
        "p_2": (0.6, 0.75),
        "p_3": (0.1, 0.25),
        "eta": (0.3, 0.45),
    }

    #%% READ IN THE INPUTS --------------------------------------------------------------
    # Keyword arguments that are the same for baseline and bootstrap runs.
    read_kwargs = {
        'path_n_rigs': "./models/first_stage/n_rigs",
        'path_surplus_components': "./models/surplus/surplus_components",
        'path_surplus_grid': './models/surplus/surplus_grid_2_low_month.npy',
        'path_df_state': './data_py/processed/states',  # original note: only used for g?
        'path_delta': './models/smm_input/delta.csv',
        'path_entry_cost': './models/smm_input/entry_cost.csv',
        'path_rho': './models/smm_input/rho.csv',
        'use_myopic': do_myopic_optimization,
    }

    if bootstrap_seed == '':
        # Baseline run: empirical moments and the processed contract data.
        read_kwargs.update(
            path_moments_data='./models/smm_input/moments_empirical.csv',
            path_df_contracts='./data_py/processed/contracts_final.csv',
            path_price_match_values='./models/price_match/price_match_values',
            path_coefs_data='./models/smm_input/coefs_data',
            path_prob_match_predict_contracts='./models/robustness/prob_match_predict_contracts',
            path_prob_match_predict='./models/robustness/prob_match_predict',
            time_period=time_period,
            p_exit=p_exit,
            bootstrap_seed='',
        )
    else:
        # Bootstrap run: seed-specific inputs; time period and exit probability
        # are pinned to 'month' and 1.0.
        read_kwargs.update(
            path_moments_data=f'./models/bootstrap/processed/moments_empirical{bootstrap_seed}.csv',
            path_df_contracts=f'./models/bootstrap/processed/df_contracts_month{bootstrap_seed}.csv',
            path_price_match_values='./models/bootstrap/price_match/price_match_values',
            path_coefs_data='./models/bootstrap/processed/coefs_data',
            path_prob_match_predict_contracts='',
            path_prob_match_predict='',
            time_period='month',
            p_exit=1.0,
            bootstrap_seed=bootstrap_seed,
        )

    data, delta, rho, c, weights = utils.read_in_data(**read_kwargs)

    #%% ADD IN FIXED PARAMETERS ---------------------------------------------------------
    # Parameters that are not estimated. Keys are inserted in the same order as
    # before, since that order carries through to the saved parameter series.
    params_fixed = {'c': c[0], 'beta': beta, 'delta': delta[0]}
    for rho_name in ('rho_0', 'rho_1', 'rho_2', 'rho_3'):
        params_fixed[rho_name] = rho[rho_name]

    # Stuff below is not important but needed for a more general model...
    # ('gamma_negative' is not fixed here; it is one of the entries in bounds.)
    params_fixed.update({
        'a_0_low': 1.0,
        'mu_1': 0.0,
        'sigma_1': 0.0,
        'weight_lambda': 1.0,
    })

    #%% (POTENTIALLY) READ IN THE MYOPIC PARAMS -----------------------------------------
    # Non-myopic runs overwrite delta with the value stored in the file for the
    # current exit probability (column '0', row label 0).
    if not do_myopic_optimization:
        delta_adjusted = pd.read_csv(
            f'./models/robustness/delta_adjusted_{p_exit}.csv',
            index_col=[0],
        )
        params_fixed['delta'] = delta_adjusted['0'].loc[0]
        print(f"DOING: {p_exit_index}, {p_exit}, delta: {params_fixed['delta']}")

    #%% DO THE SIMULATION ---------------------------------------------------------------
    # Settings handed to simulation.run_steps alongside the data.
    args_config = dict(
        x_names=list(bounds),  # parameter names, in the same order as the bounds
        params_fixed=params_fixed,
        weights=weights,
        mri_max=mri_max,
        verbose=False,
        verbose_output=False,
        value_zero=False,
        entry_prob_by_tau_by_ym=None,
        tau_multiplier=tau_multiplier,
        well_target=True,
    )

    # Settings first, then the data entries; on a key clash the data entry wins.
    args_by_names = dict(args_config)
    args_by_names.update(data)

    #%%
    if do_optimization:
        # Plain 'month'/'fortnight' runs use the global differential-evolution
        # search; 'month_local' and every other period (bootstrap) use the local
        # Nelder-Mead search started from saved parameters.
        use_global_search = (
            time_period in ['month', 'fortnight'] and input_1 != 'month_local'
        )
        if use_global_search:
            # The context manager shuts the worker pool down once the search has
            # returned; previously the pool was never closed.
            with mp.Pool(workers, maxtasksperchild=200) as pool:
                output = scipy.optimize.differential_evolution(
                    func=simulation.run_steps,
                    seed=8,
                    bounds=tuple(bounds.values()),
                    args=(args_by_names,),
                    updating='deferred',
                    workers=pool.map
                )
        else:
            print("Doing optimization - nelder mead")
            # Only a 'month_local' run with an exit probability below 1 starts
            # from the p_exit-specific file; everything else uses the init file.
            if input_1 == 'month_local' and p_exit != 1.0:
                path_init = f'./models/smm_input/params_smm_with_diff_new_myopic_{p_exit}.csv'
            else:
                path_init = './models/smm_input/params_smm_with_diff_new_init.csv'
            params_init = pd.read_csv(path_init, index_col=[0])
            print("Initial params:")
            print(params_init)
            # Selecting rows by label keeps a two-dimensional DataFrame, while
            # minimize() expects a one-dimensional start vector, so flatten it.
            # Assumes the file has a single value column, as the original did.
            x0 = params_init.loc[list(bounds.keys())].to_numpy().ravel()
            output = scipy.optimize.minimize(
                fun=simulation.run_steps,
                x0=x0,
                bounds=tuple(bounds.values()),
                args=(args_by_names,),
                method='Nelder-Mead'
            )

        # Estimated parameters first, then the fixed ones (fixed wins on a clash).
        params_smm = dict(zip(args_by_names['x_names'], output.x))
        params_final = {**params_smm, **params_fixed}
        if save_results:
            pd.Series(params_final).to_csv(path_save)

        # Re-run once at the optimum with verbose output switched on.
        args_by_names['verbose'] = True
        simulation.run_steps(x=list(output.x), args_by_names=args_by_names)
        print(pd.Series(params_smm))
